{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 图像分类数据集\n",
    ":label:`sec_fashion_mnist`\n",
    "\n",
    "目前广泛使用的图像分类数据集之一是 MNIST 数据集 :cite:`LeCun.Bottou.Bengio.ea.1998`。虽然它是很不错的基准数据集，但按今天的标准，即使是简单的模型也能达到95%以上的分类准确率，因此不适合区分强模型和弱模型。如今，MNIST更像是一个健全检查，而不是一个基准。\n",
    "为了提高难度，我们将在接下来的章节中讨论在2017年发布的性质相似但相对复杂的Fashion-MNIST数据集 :cite:`Xiao.Rasul.Vollgraf.2017`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/StopWatch.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import java.awt.image.BufferedImage;\n",
    "import java.awt.Graphics2D;\n",
    "import java.awt.Graphics;\n",
    "import java.awt.Color;\n",
    "import java.awt.Dimension;\n",
    "import java.awt.FlowLayout;\n",
    "import java.awt.Component;\n",
    "import javax.swing.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据集\n",
    "\n",
    "就像`MNIST`，我们可以使用 DJL `ai.djl.basicdataset` 中的 `FashionMnist` 类来下载并读取到内存中。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 5,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "boolean randomShuffle = true;\n",
    "\n",
    "FashionMnist mnistTrain = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TRAIN)\n",
    "        .setSampling(batchSize, randomShuffle)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "FashionMnist mnistTest = FashionMnist.builder()\n",
    "        .optUsage(Dataset.Usage.TEST)\n",
    "        .setSampling(batchSize, randomShuffle)\n",
    "        .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "        .build();\n",
    "\n",
    "mnistTrain.prepare();\n",
    "mnistTest.prepare();\n",
    "\n",
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fashion-MNIST 由 10 个类别的图像组成，每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像组成。*测试数据集*（test dataset）不会用于训练，只用于评估模型性能。训练集和测试集分别包含 60000 和 10000 张图像。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 9,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "System.out.println(mnistTrain.size());\n",
    "System.out.println(mnistTest.size());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fashion-MNIST中包含的10个类别分别为t-shirt（T恤）、trouser（裤子）、pullover（套衫）、dress（连衣裙）、coat（外套）、sandal（凉鞋）、shirt（衬衫）、sneaker（运动鞋）、bag（包）和ankle boot（短靴）。以下函数用于在数字标签索引及其文本名称之间进行转换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 14,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "// Saved in the FashionMnist class for later use\n",
    "public String[] getFashionMnistLabels(int[] labelIndices) {\n",
    "    String[] textLabels = {\"t-shirt\", \"trouser\", \"pullover\", \"dress\", \"coat\",\n",
    "                   \"sandal\", \"shirt\", \"sneaker\", \"bag\", \"ankle boot\"};\n",
    "    String[] convertedLabels = new String[labelIndices.length];\n",
    "    for (int i = 0; i < labelIndices.length; i++) {\n",
    "        convertedLabels[i] = textLabels[labelIndices[i]];\n",
    "    }\n",
    "    return convertedLabels;\n",
    "}\n",
    "\n",
    "public String getFashionMnistLabel(int labelIndice) {\n",
    "    String[] textLabels = {\"t-shirt\", \"trouser\", \"pullover\", \"dress\", \"coat\",\n",
    "                   \"sandal\", \"shirt\", \"sneaker\", \"bag\", \"ankle boot\"};\n",
    "    return textLabels[labelIndice];\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们现在可以创建一个函数来可视化这些样本。\n",
    "\n",
    "下面的代码只是为了帮助直观地理解数据, 你不需要太关注可视化的细节。我们读取了许多数据点并将它们的 `RGB` 值从 0-255 转换为 0-1 之间。然后我们将颜色以灰度的形式，将其与标签一起显示出来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 16,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "public class ImagePanel extends JPanel {\n",
    "    int SCALE;\n",
    "    BufferedImage img;\n",
    "\n",
    "    public ImagePanel() {\n",
    "        this.SCALE = 1;\n",
    "    }\n",
    "    public ImagePanel(int scale, BufferedImage img) {\n",
    "        this.SCALE = scale;\n",
    "        this.img = img;\n",
    "    }\n",
    "    @Override\n",
    "    protected void paintComponent(Graphics g) {\n",
    "        Graphics2D g2d = (Graphics2D)g;\n",
    "        g2d.scale(SCALE, SCALE);\n",
    "        g2d.drawImage(this.img, 0, 0, this);\n",
    "    }\n",
    "}\n",
    "\n",
    "public class Container extends JPanel {\n",
    "    public Container(String label) {\n",
    "        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));\n",
    "        JLabel l = new JLabel(label, JLabel.CENTER);\n",
    "        l.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l);\n",
    "    }\n",
    "    public Container(String trueLabel, String predLabel) {\n",
    "        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));\n",
    "        JLabel l = new JLabel(trueLabel, JLabel.CENTER);\n",
    "        l.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l);\n",
    "        JLabel l2 = new JLabel(predLabel, JLabel.CENTER);\n",
    "        l2.setAlignmentX(Component.CENTER_ALIGNMENT);\n",
    "        add(l2);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Saved in the FashionMnistUtils class for later use\n",
    "public void showImages(ArrayDataset dataset,\n",
    "                       int number, int WIDTH, int HEIGHT, int SCALE,\n",
    "                       NDManager manager)\n",
    "    throws IOException, TranslateException {\n",
    "    // Plot a list of images\n",
    "    JFrame frame = new JFrame(\"Fashion Mnist\");\n",
    "    for (int record = 0; record < number; record++) {\n",
    "        NDArray X = dataset.get(manager, record).getData().get(0).squeeze(-1);\n",
    "        int y = (int)dataset.get(manager, record).getLabels().get(0).getFloat();\n",
    "        BufferedImage img = new BufferedImage(WIDTH, HEIGHT, BufferedImage.TYPE_BYTE_GRAY);\n",
    "        Graphics2D g = (Graphics2D) img.getGraphics();\n",
    "        for(int i = 0; i < WIDTH; i++) {\n",
    "            for(int j = 0; j < HEIGHT; j++) {\n",
    "                float c = X.getFloat(j, i) / 255;  // scale down to between 0 and 1\n",
    "                g.setColor(new Color(c, c, c)); // set as a gray color\n",
    "                g.fillRect(i, j, 1, 1);\n",
    "            }\n",
    "        }\n",
    "        JPanel panel = new ImagePanel(SCALE, img);\n",
    "        panel.setPreferredSize(new Dimension(WIDTH * SCALE, HEIGHT * SCALE));\n",
    "        JPanel container = new Container(getFashionMnistLabel(y));\n",
    "        container.add(panel);\n",
    "        frame.getContentPane().add(container);\n",
    "    }\n",
    "    frame.getContentPane().setLayout(new FlowLayout());\n",
    "    frame.pack();\n",
    "    frame.setVisible(true);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以下是训练数据集中前几个样本的图像及其相应的标签（文本形式）。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 19,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "final int SCALE = 4;\n",
    "final int WIDTH = 28;\n",
    "final int HEIGHT = 28;\n",
    "\n",
    "/* Uncomment the following line and run to display images.\n",
    "   It will open in another window. */\n",
    "// showImages(mnistTrain, 18, WIDTH, HEIGHT, SCALE, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Fashion Mnist labels.](https://d2l-java-resources.s3.amazonaws.com/img/fashion_mnist_labels.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取小批量\n",
    "\n",
    "为了使我们在读取训练集和测试集时更容易，我们使用 `getData(manager)`。回顾一下，在每次迭代中，`getData(manager)` 每次都会读取一小批量数据，大小为 `batchSize`。我们可以用 `getData()` 和 `getLabels()` 来得到`x`和`y`。\n",
    "\n",
    "让我们看一下读取训练数据所需的时间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "origin_pos": 27,
    "tab": [
     "mxnet"
    ]
   },
   "outputs": [],
   "source": [
    "StopWatch stopWatch = new StopWatch();\n",
    "stopWatch.start();\n",
    "for (Batch batch : mnistTrain.getData(manager)) {\n",
    "    NDArray x = batch.getData().head();\n",
    "    NDArray y = batch.getLabels().head();\n",
    "}\n",
    "System.out.println(String.format(\"%.2f sec\", stopWatch.stop()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们现在已经准备好在下面的章节中使用Fashion-MNIST数据集。\n",
    "\n",
    "## 小结\n",
    "\n",
    "* Fashion-MNIST是一个服装分类数据集，由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。\n",
    "* 我们将高度$h$像素，宽度$w$像素图像的形状记为$h \\times w$或($h$, $w$)。\n",
    "* 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器，利用高性能计算来避免减慢训练过程。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 减少 `batchSize`（如减少到 1）是否会影响读取性能？\n",
    "1. 数据迭代器的性能非常重要。你认为当前的实现足够快吗？探索各种选择来改进它。\n",
    "1. 查阅框架的在线API文档。还有哪些其他数据集可用？\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
